

###################################
#### cell type analysis        ####


# perform QC for all cell types
# have an option to filter out suspect cell types
# have an option to perform gene filtering for all cell types

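# data: the normalized gene expression (gx) data, with genes in the columns
# celltype: named vector of cell type assignments, one entry per gene (names are gene IDs matching the columns of data).  Probably mostly NA.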


#
# 
# Cell type abundance analysis: score every cell type, QC its marker genes,
# filter out incoherent cell types, and plot / export the scores.
#
# For each cell type in `celltype`:
#   * its marker genes present in `data` are selected with celltype.selectgenes()
#     (or all of them are used when use.all.genes is TRUE);
#   * test.cell.type.genes() gives a permutation p-value for how coherently the
#     selected genes co-vary (sets of 0 or 1 genes get p = 1);
#   * each sample (row of data) is scored as the mean of the selected genes;
#   * a QC plot of the genes (selected and discarded) is drawn.
# Cell types are then dropped if p > cell.type.pval.thresh (single-gene cell
# types exempt), if the p-value is NA, or if no genes were selected. The
# surviving "raw" scores are optionally multiplied by `celltypecontrasts` to
# give "relative" scores; both are plotted as requested (heatmaps, pairs
# plots, PCA, plots vs covariates) and written to .csv files.
#
# Arguments:
#   data                     normalized expression; columns are genes (matched
#                            against names(celltype)), rows are samples.
#   prb.annot                probe annotation indexed by gene; its "Probe.Label"
#                            and "Analyte.Type" columns label the QC plots.
#   celltype                 named vector: names are genes, values are cell type
#                            labels (NA for genes that mark no cell type).
#   use.all.genes            TRUE to skip gene selection.
#   lowcor                   similarity threshold for celltype.selectgenes().
#   cell.type.pval.thresh    maximum p-value for a cell type to be kept.
#   celltypes.arg.showraw, celltypes.arg.showrelative
#                            whether to draw plots for raw / relative scores.
#   celltypecontrasts        cell types (rows) x contrasts (columns) weights.
#   plot.against.covariates  whether to plot and test scores vs covariates.
#   covariates, covariate.type
#                            covariate table and per-column type
#                            ("continuous" or "categorical").
#   annotcols, annotcols2    covariate colours: per-sample colour columns and
#                            per-level palettes, respectively.
#   log, path.to.csvs        not used by this function.
#   plottypearg              plot types; every plot is drawn once per entry.
#   path.results             prefix replaced by "results" in image links.
#   path.inc                 folder of panel2_1.js, to which image links for
#                            the covariate plots are appended.
#   path.to.celltype.results output folder for plots and .csv files.
#
# Returns a list:
#   celltypescoreslist  list with "raw" (samples x cell types) and, when any
#                       contrast survives, "relative" (samples x contrasts).
#   celltypepvals       p-values of the retained cell types.
#   genesselected       selected genes of each retained cell type.
#   warnings.paragraph  accumulated warning text for the report.
immune.cell.abundance.analysis = function(data,
                                          prb.annot,
                                          celltype,
                                          use.all.genes,
                                          lowcor,
                                          cell.type.pval.thresh,
                                          celltypes.arg.showraw,
                                          celltypes.arg.showrelative,
                                          celltypecontrasts,
                                          plot.against.covariates,
                                          covariates,
                                          covariate.type,
                                          annotcols,
                                          annotcols2,
                                          log,
                                          plottypearg,
                                          path.results,
                                          path.inc,
                                          path.to.csvs,
                                          path.to.celltype.results)
{
  warnings.paragraph <- ""

  # Append an image entry for a score-vs-covariate plot to panel2_1.js.
  append.panel.entry <- function(cov.name, tempfilename) {
    strTemp <- paste("document.write('              	    		  	   <li> Score vs ", cov.name, "<br><img src=\"", tempfilename, ".png\"></li>');\n", sep = "")
    cat(strTemp, file = paste(path.inc, "//panel2_1.js", sep = ""), append = TRUE)
  }

  ## ---- screen cell types ---------------------------------------------------
  # Blank out cell types with no genes (table() only reports zero counts for
  # unused factor levels). Single-gene cell types are deliberately kept even
  # when use.all.genes is FALSE (changed 6-16); they are exempted from the
  # p-value filter below.
  celltype.counts <- table(celltype)
  celltype.0genes <- names(celltype.counts)[celltype.counts == 0]
  celltype.1genes <- names(celltype.counts)[celltype.counts == 1]
  celltype[is.element(as.vector(celltype), celltype.0genes)] <- NA

  celltypes <- setdiff(unique(celltype), NA)
  n.celltypes <- length(celltypes)

  ## ---- per cell type: gene selection, p-value, score, QC plot ---------------
  celltypepvals <- rep(NA, n.celltypes)
  names(celltypepvals) <- celltypes
  genesselected <- vector("list", n.celltypes)
  names(genesselected) <- celltypes
  celltypescores <- matrix(NA, nrow(data), n.celltypes,
                           dimnames = list(rownames(data), celltypes))

  for (i in seq_len(n.celltypes)) {
    ctype <- celltypes[i]
    # this cell type's genes that are present in the data
    cgenes <- setdiff(names(celltype)[celltype == ctype], NA)
    cgenes <- intersect(cgenes, colnames(data))
    cdata <- data[, cgenes, drop = FALSE]

    if (use.all.genes) {
      selected <- colnames(cdata)
    } else {
      selected <- celltype.selectgenes(cdata, lowcor)
    }
    # A zero-column cdata yields NULL names; store character(0) instead,
    # because `genesselected[[i]] <- NULL` would delete element i and shift
    # every later cell type's entry.
    if (is.null(selected)) selected <- character(0)
    genesselected[[i]] <- selected

    # the coherence test needs at least two genes
    if (length(selected) <= 1) {
      celltypepvals[i] <- 1
    } else {
      celltypepvals[i] <- test.cell.type.genes(data, selected, p0 = length(cgenes))
    }

    # sample scores: mean expression of the selected genes
    celltypescores[, i] <- apply(data[, selected, drop = FALSE], 1, mean)

    # QC plot: selected genes first, then discarded ones (labelled as such).
    # Skipped when none of the cell type's genes are in the data.
    if (ncol(cdata) > 0) {
      tempselected <- intersect(colnames(cdata), selected)
      tempunselected <- setdiff(colnames(cdata), selected)
      tempdat <- cdata[, c(tempselected, tempunselected), drop = FALSE]
      colnames(tempdat) <- paste(paste(prb.annot[colnames(tempdat), "Probe.Label"], "\n(", prb.annot[colnames(tempdat), "Analyte.Type"], ")", sep = ""),
                                 c(rep("", length(tempselected)), rep("\n(discarded)", length(tempunselected))), sep = "")
      for (r in seq_along(plottypearg)) {
        plottype <- plottypearg[r]
        drawplot(paste(path.to.celltype.results, "//", "QC for cell scores - ", ctype, sep = ""),
                 plottype, width = 1.5, height = 2)
        if (ncol(tempdat) > 1) {
          pairs.adaptive.cex(tempdat, main = paste(ctype, "\np =", round(celltypepvals[i], 2)), col = "black")
        } else {
          hist(as.vector(tempdat[, 1]), main = ctype, xlab = colnames(tempdat)[1], border = "black")
        }
        dev.off()
      }
    }
  }

  ## ---- drop cell types that fail QC ------------------------------------------
  # Non-significant or untestable (NA p-value) cell types are dropped, except
  # single-gene cell types; cell types with no selected genes are always dropped.
  not.sig <- celltypepvals > cell.type.pval.thresh
  not.sig[is.na(not.sig)] <- TRUE
  not.sig[celltype.1genes] <- FALSE
  no.genes <- vapply(genesselected,
                     function(g) length(g) == 0 || identical(NA, g),
                     logical(1))
  celltype.remove <- not.sig | no.genes
  names(celltype.remove) <- celltypes

  celltypescores <- celltypescores[, !celltype.remove, drop = FALSE]
  genesselected <- genesselected[!celltype.remove]
  celltypepvals <- celltypepvals[!celltype.remove]

  celltypescoreslist <- list(raw = celltypescores)

  if (ncol(celltypescores) == 0) {
    wtmp <- "Warning: no cell types passed the QC metrics.  Consider lowering the p-value threshold or lowering the expression threshold."
    print(wtmp)
    warnings.paragraph <- paste(warnings.paragraph, wtmp, "\n")
  }

  if (ncol(celltypescores) > 0) {
    ## ---- contrasts ("relative" scores) -----------------------------------------
    # Drop contrast rows for cell types that have no score, and every contrast
    # that puts non-zero weight on such a cell type.
    extraneous.cell.types <- setdiff(rownames(celltypecontrasts), colnames(celltypescores))
    contrasts.w.extraneous.cells <- colnames(celltypecontrasts)[colSums(abs(celltypecontrasts)[extraneous.cell.types, , drop = FALSE]) > 0]
    celltypecontrasts <- celltypecontrasts[setdiff(rownames(celltypecontrasts), extraneous.cell.types),
                                           setdiff(colnames(celltypecontrasts), contrasts.w.extraneous.cells),
                                           drop = FALSE]
    if (length(extraneous.cell.types) > 0) {
      wtmp <- paste("Warning: the following cell types in the cell types contrast file were not found in the probe annotation file's cell type column:\n", paste(extraneous.cell.types, collapse = ","), "\n")
      if (length(contrasts.w.extraneous.cells) > 0) {
        wtmp <- paste(wtmp, "Therefore the following contrasts have been discarded:\n", paste(contrasts.w.extraneous.cells, collapse = "\n"), "\n")
      }
      warnings.paragraph <- paste(warnings.paragraph, wtmp)
      print(wtmp)
    }

    if (ncol(celltypecontrasts) == 0) {
      wtmp <- "Warning: no cell type contrasts remained with valid data.  Only raw cell type abundance will be reported."
      print(wtmp)
      warnings.paragraph <- paste(warnings.paragraph, wtmp, "\n")
      celltypes.arg.showrelative <- FALSE
    } else {
      celltypes.intersect <- intersect(colnames(celltypescores), rownames(celltypecontrasts))
      celltypescoreslist[["relative"]] <- celltypescores[, celltypes.intersect, drop = FALSE] %*%
        as.matrix(celltypecontrasts[celltypes.intersect, , drop = FALSE])
    }
    str(celltypescoreslist)

    ## ---- plots and .csv output for raw and relative scores ---------------------
    suffix <- c("raw", "relative")
    list.elements.to.plot <- c(celltypes.arg.showraw, celltypes.arg.showrelative)
    names(list.elements.to.plot) <- suffix

    for (si in seq_along(suffix)) {
      if (length(celltypescoreslist) < si) next
      scores <- celltypescoreslist[[si]]
      n.scores <- ncol(scores)
      score.names <- colnames(scores)

      if (list.elements.to.plot[si]) {
        # a single score: histogram instead of heatmap (same file name)
        if (n.scores == 1) {
          for (r in seq_along(plottypearg)) {
            plottype <- plottypearg[r]
            drawplot(filename = paste(path.to.celltype.results, "//heatmap of cell types scores - ", suffix[si], sep = ""), plottype, width = 1.5, height = 1.5)
            par(mar = c(10, 4, 2, 1))
            hist(scores[, 1], xlab = score.names[1], main = "")
            dev.off()
          }
        }

        if (n.scores > 1) {
          # heatmap of the scores
          hmcols <- colorRampPalette(c("cornflowerblue", "black", "orange"))(256)
          for (r in seq_along(plottypearg)) {
            plottype <- plottypearg[r]
            drawplot(filename = paste(path.to.celltype.results, "//heatmap of cell types scores - ", suffix[si], sep = ""), plottype, width = 1.5, height = 1.5)
            par(mar = c(10, 4, 2, 1))
            if (length(covariates) > 0) {
              heatmap.3(scores, scale = "column", symm = FALSE, Colv = TRUE, RowSideColors = t(annotcols[, colnames(covariates), drop = FALSE]), margins = c(11, 11), cexRow = 1, col = hmcols)
            } else {
              heatmap.3(scores, scale = "column", symm = FALSE, Colv = TRUE, margins = c(11, 11), cexRow = 1, col = hmcols)
            }
            dev.off()
          }

          # heatmap of the score correlation matrix
          corcols <- colorRampPalette(c("blue", "antiquewhite3", "firebrick"))(256)
          for (r in seq_along(plottypearg)) {
            plottype <- plottypearg[r]
            drawplot(filename = paste(path.to.celltype.results, "//cell type scores correlation heatmap - ", suffix[si], sep = ""), plottype, width = 1.7, height = 1.7, heatmapres = TRUE)
            par(mar = c(10, 4, 2, 1))
            cormat2 <- cor(scores)
            heatmap.2(cormat2, col = corcols, symm = TRUE, Colv = "Rowv", margins = c(8, 8), cexRow = 0.75, cexCol = 0.75, trace = "n", density = "none", keysize = 1, key.xlab = "correlation")
            dev.off()
          }

          # point colours for the pairs plots: one entry per covariate when
          # plotting against covariates, otherwise a single black entry
          pairscols <- list("black")
          covariate.names.for.pairsplots <- c("")
          if (plot.against.covariates && length(covariates) > 0) {
            for (i in seq_len(ncol(covariates))) {
              pairscols[[i]] <- annotcols[, i]
              covariate.names.for.pairsplots[i] <- colnames(covariates)[i]
            }
          }

          # pairs plots: each score against every other score
          celltypepairsplotsize <- if (n.scores >= 16) 2 else if (n.scores > 4) 1.5 else 1
          grid.side <- ceiling(sqrt(n.scores))
          for (i in seq_along(pairscols)) {
            for (k in seq_len(n.scores)) {
              for (r in seq_along(plottypearg)) {
                plottype <- plottypearg[r]
                plotname <- paste(path.to.celltype.results, "//", "cell scores pairs plot - ", suffix[si], " - ", score.names[k], sep = "")
                if (nchar(covariate.names.for.pairsplots[1]) > 0) {
                  plotname <- paste(plotname, " - colored by ", covariate.names.for.pairsplots[i], sep = "")
                }
                drawplot(plotname, plottype, width = celltypepairsplotsize, height = celltypepairsplotsize)
                par(mfrow = c(grid.side, grid.side))
                par(mar = c(4, 6, 3, 0))
                for (j in setdiff(seq_len(n.scores), k)) {
                  plot(scores[, k], scores[, j], xlab = score.names[k], ylab = score.names[j],
                       col = pairscols[[i]], cex.lab = 1.2, pch = 16, cex = (1 + (nrow(scores) < 50)))
                }
                dev.off()
              }
            }
          }

          # pairs plots of the first (up to 4) principal components of the scores
          pc <- prcomp(scores, scale. = TRUE)
          pch <- 1 + 15 * (nrow(pc$x) < 100)
          for (i in seq_along(pairscols)) {
            for (r in seq_along(plottypearg)) {
              plottype <- plottypearg[r]
              if (length(pairscols) > 1) {
                drawplot(paste(path.to.celltype.results, "//", "principal components fit to celltype scores - pairs plot colored by ", covariate.names.for.pairsplots[i], sep = ""), plottype)
              } else {
                drawplot(paste(path.to.celltype.results, "//", "principal components fit to celltype scores pairs plot", sep = ""), plottype)
              }
              pairs.adaptive.cex(pc$x[, 1:min(4, ncol(pc$x))], col = pairscols[[i]], xaxt = "n", yaxt = "n", pch = pch)
              dev.off()
            }
          }
        }

        ## scores vs covariates: plots, univariate p-values and trend plots
        if (plot.against.covariates && length(covariates) > 0) {
          # NOTE(review): pvals is computed but neither returned nor written.
          pvals <- matrix(NA, n.scores, ncol(covariates),
                          dimnames = list(score.names, colnames(covariates)))
          for (j in seq_len(n.scores)) {
            score.label <- paste(score.names[j], "score")
            for (i in seq_len(ncol(covariates))) {
              cov.name <- colnames(covariates)[i]
              plotname <- paste(path.to.celltype.results, "//", cov.name, " - ", suffix[si], " - ", score.names[j], sep = "")
              if (covariate.type[i] == "continuous") {
                for (r in seq_along(plottypearg)) {
                  plottype <- plottypearg[r]
                  tempfilename <- drawplot(plotname, plottype)
                  tempfilename <- gsub(path.results, "results", tempfilename, fixed = TRUE)
                  if (r == 1) append.panel.entry(cov.name, tempfilename)
                  plot(scores[, j] ~ covariates[, i], xlab = cov.name,
                       ylab = score.label, cex.lab = 1.5, cex = (1 + (nrow(scores) < 50)))
                  notmissing <- !is.na(covariates[, i])
                  lines(lowess(covariates[notmissing, i], scores[notmissing, j]), col = "grey25")
                  dev.off()
                }
                pvals[j, i] <- summary(lm(scores[, j] ~ covariates[, i]))$coefficients[2, 4]
              } else if (covariate.type[i] == "categorical") {
                x <- as.factor(covariates[, i])
                y <- as.vector(scores[, j])
                for (r in seq_along(plottypearg)) {
                  plottype <- plottypearg[r]
                  tempfilename <- drawplot(plotname, plottype)
                  tempfilename <- gsub(path.results, "results", tempfilename, fixed = TRUE)
                  if (r == 1) append.panel.entry(cov.name, tempfilename)
                  par(mar = c(10, 6, 2, 1))
                  boxplot(y ~ x, border = "darkgrey", ylab = score.label, outline = FALSE, las = 2, cex.lab = 1.5)
                  points(y ~ jitter(as.numeric(x)), col = annotcols2[[i]][x], pch = 16)
                  dev.off()
                }
                complete <- !is.na(covariates[, i])
                pvals[j, i] <- anova(lm(scores[complete, j] ~ covariates[complete, i]), lm(scores[complete, j] ~ 1))[[6]][2]
              }
              # one trend plot of all scores per covariate
              if (j == 1) {
                for (r in seq_along(plottypearg)) {
                  plottype <- plottypearg[r]
                  drawplot(paste(path.to.celltype.results, "//", "trend plot of cell type scores vs. ", cov.name, " - ", suffix[si], " - legend", sep = ""),
                           plottype, width = 2)
                  trendplot(scores,
                            covariate = covariates[, i],
                            covariate.name = cov.name,
                            covariate.type = covariate.type[i],
                            center = TRUE,
                            scale = FALSE,
                            ylab = "Cell type scores",
                            lwd = 2,
                            lty = rep_len(1:6, n.scores))
                  dev.off()
                }
              }
            }
          }
        }
      } # end if (list.elements.to.plot[si])
      write.csv(scores, file = paste(path.to.celltype.results, "//cell type scores - ", suffix[si], ".csv", sep = ""))
    }
  }

  list(celltypescoreslist = celltypescoreslist,
       celltypepvals = celltypepvals,
       genesselected = genesselected,
       warnings.paragraph = warnings.paragraph)
}

# Pairwise similarity between two genes: their sample covariance divided by
# the mean of their two sample variances. Like a correlation it lies in
# [-1, 1], but it reaches +/-1 only when the genes co-vary with a slope of
# +/-1 (equal variances), not merely when they are perfectly correlated.
celltypegene.pairwise.dist = function(x1,x2)
{
  n.obs <- length(x1)
  dev1 <- x1 - mean(x1)
  dev2 <- x2 - mean(x2)
  avg.var <- mean(c(var(x1), var(x2)))
  cross <- sum(dev1 * dev2)
  cross / ((n.obs - 1) * avg.var)
}
# Similarity matrix amongst the genes (columns) of cdata. Entry [i, j] is
# cov(gene i, gene j) / mean(var(gene i), var(gene j)), the same quantity
# celltypegene.pairwise.dist() computes for a single pair; the diagonal is
# fixed at 1. Computed for all pairs at once with cov(), because this is
# called once per permutation in test.cell.type.genes().
#
# Returns a p x p matrix with the gene names as dimnames. With fewer than two
# genes it returns a p x p matrix of 1s (previously an unassigned local made
# it return the stats::dist function instead of a matrix).
celltypegene.dist = function(cdata)
{
  p <- ncol(cdata)
  gene.names <- colnames(cdata)
  if (p < 2) {
    return(matrix(1, p, p, dimnames = list(gene.names, gene.names)))
  }
  covmat <- cov(cdata)
  gene.var <- diag(covmat)
  sharedvar <- outer(gene.var, gene.var, "+") / 2
  dist <- covmat / sharedvar
  diag(dist) <- 1
  dimnames(dist) <- list(gene.names, gene.names)
  dist
}

# Iteratively discard genes that do not co-vary with the rest of a cell type's
# candidate gene set.
#
# cdata:  matrix whose columns are one cell type's candidate genes.
# lowcor: threshold on the celltypegene.dist() similarity; a gene pair whose
#         similarity is below lowcor counts as a "low" pair (NAs are ignored).
#
# Each pass removes the single gene that is in the most low pairs (ties broken
# by the lowest total similarity, then by column order). Passes stop once no
# low pair remains, only two genes remain, or ncol(cdata) + 1 passes have run.
# Gene sets of two or fewer genes are returned unchanged.
#
# Returns the column names of the retained genes.
celltype.selectgenes = function(cdata,lowcor) #,bgthresh)
{
  max.iter <- ncol(cdata) + 1
  iter <- 0
  total.low <- Inf  # forces at least one pass when there are > 2 genes
  genedist <- if (ncol(cdata) > 2) celltypegene.dist(cdata) else NULL
  while (total.low > 0 && iter < max.iter && ncol(cdata) > 2) {
    # number of low-similarity partners of each gene
    nlow <- rowSums(genedist < lowcor, na.rm = TRUE)
    to.remove <- which(nlow == max(nlow) & nlow > 0)
    # only remove the one with the lowest total similarity
    if (length(to.remove) > 1) {
      totals <- rowSums(genedist[to.remove, , drop = FALSE], na.rm = TRUE)
      to.remove <- to.remove[which.min(totals)]
    }
    if (length(to.remove) > 0) {
      cdata <- cdata[, -to.remove, drop = FALSE]
    }
    genedist <- if (ncol(cdata) > 1) celltypegene.dist(cdata) else matrix(-Inf, 1, 1)
    # na.rm keeps an NA similarity from turning the loop condition into NA
    total.low <- sum(genedist < lowcor, na.rm = TRUE)
    iter <- iter + 1
  }
  colnames(cdata)
}

# Concordance of a gene set (columns of cdata): sum(cov) / (p * trace(cov)),
# i.e. the variance of the equally-weighted (1/sqrt(p)) gene combination
# relative to the total per-gene variance. It equals 1 when all genes are
# identical up to a constant and about 1/p for uncorrelated genes of equal
# variance. The covariance matrix is computed once. A plain vector is
# treated as a single gene, giving 1.
score.celltype.gene.concordance = function(cdata)
{
  cdata <- as.matrix(cdata)
  p <- ncol(cdata)
  covmat <- cov(cdata)
  sum(covmat) / (p * sum(diag(covmat)))
}

# Permutation test for the coherence of a cell type's selected genes.
#
# data:          matrix of all genes (columns).
# genesselected: the genes retained by gene selection for this cell type.
# p0:            number of candidate genes the selection started from.
# B:             number of permutations.
#
# Each permutation draws p0 random genes from data, keeps the p genes with the
# highest mean celltypegene.dist() similarity (a rough imitation of the
# selection step) and scores their concordance. Returns the fraction of
# permutations whose concordance strictly exceeds the observed one.
# NOTE(review): there is no +1 correction, so the p-value can be exactly 0.
test.cell.type.genes = function(data,genesselected,p0,B=1000)
{
  cdata <- data[, intersect(colnames(data), genesselected), drop = FALSE]
  p <- length(genesselected)
  stat <- score.celltype.gene.concordance(cdata)
  bstats <- vapply(seq_len(B), function(b) {
    tdata <- data[, sample(seq_len(ncol(data)), p0), drop = FALSE]
    tdist <- celltypegene.dist(tdata)
    best <- order(rowMeans(tdist, na.rm = TRUE), decreasing = TRUE)[seq_len(p)]
    score.celltype.gene.concordance(tdata[, best, drop = FALSE])
  }, numeric(1))
  mean(bstats > stat)
}
